library(yaml)
library(dplyr)
library(tidyr)
library(purrr)
library(stringr)
library(ggplot2)
library(ggcorrplot)
library(scales)

source("helper_funs.R")
source("auxiliary_regression_funcs.R")
configs <- read_yaml("configs.yml")

# Read the three positional command-line settings (burn, iter, thin).
# Further down they are pasted together ("<burn>_<iter>_<thin>") to form the
# suffix of the result files that get loaded, so the script cannot complete
# without them. Abort immediately with the usage message instead of failing
# later with an unrelated `$` error on an atomic vector.
args <- commandArgs(trailingOnly = TRUE)
if (length(args) == 3) {
  args <- as.list(args)
  names(args) <- c("burn", "iter", "thin")
} else {
  stop("Please provide burn, iter and thin parameters", call. = FALSE)
}

in_path <- file.path("..")
out_path <- file.path("..", "results")
# dir.create(out_path, showWarnings = FALSE)

#-------------------------------------------------------------------------------
# SUPPLEMENT FIGURE 4: 'Simpler approach'
# response: Difference in log odds between percent voting with party on proc
# and final

# Congresses to include in the figure.
CONGS <- 93:113

# Run the regression helper once per Congress and stack the per-Congress
# results into a single tibble keyed by `cong`.
res_with_party_bayes <- tibble(cong = CONGS) %>%
  mutate(res = map(cong, run_regression_percent_with_party,
                   in_path = in_path,
                   penalized = FALSE, bayes = TRUE)) %>%
  unnest(res)

# Covariate ordering passed to the plotting helper: party covariates (sorted),
# then legislator covariates, then constituency covariates, all from configs.
ORDER <- c(sort(configs$covariate_cats$party),
           configs$covariate_cats$legislator,
           configs$covariate_cats$constituancy)

# hacky - flip signs of coefs so outcome is log-odds(proc) - log-odds(final)
res_with_party_bayes <-
  res_with_party_bayes %>% mutate(coef = -1 * coef)

# ggsave() is called without `plot`, so it writes the most recent ggplot.
plot_redblue_one(res_with_party_bayes, ORDER)
ggsave(file.path(out_path, "fig4_supplement_alt_regression_bayes.jpeg"),
       height = 3.5, width = 6)


#-------------------------------------------------------------------------------
## SUPPLEMENT FIGURE 1: Bridge probs vs covariate adjusted priors

in_path <- file.path("..", "results", "stage0")
out_path <- file.path("..", "results")
# File-name suffix of the stored results, built from the command-line settings.
in_file_suffix <- paste0(args$burn, "_", args$iter, "_", args$thin)
# The single Congress shown in this figure.
which_cong <- 107

# get coefficients
output <- read_one(which_cong, in_path, in_file_suffix)
eta <- output$stage2[[1]][[1]]$eta

# get X
# Input files are named inputs_H<3-digit congress>.RDS (e.g. inputs_H107.RDS).
cong_str <- paste0("H", str_pad(which_cong, 3, pad = "0"))
input <- readRDS(file.path("..", "data",
                           paste0("inputs_", cong_str, ".RDS")))
x_raw <- input$covariates
# Prepend a column of ones and apply `standardize` to every column.
# NOTE(review): the intercept column is passed through `standardize` as well;
# confirm that the helper leaves a constant column intact.
x <- cbind(1, x_raw[, configs$covariates]) %>%
  mutate_all(standardize) %>%
  as.matrix()
party <- input$members$party

# Posterior mean of the covariate-adjusted prior probability of being a bridge
# (assumes `eta` holds one row per posterior draw and one column per column of
# `x`, so `linpred` is legislators x draws -- TODO confirm).
linpred <- x %*% t(eta)
covadjusted.priorprob <- 1 / (1 + exp(-linpred))
covadjusted.priorprob.mean <- rowMeans(covadjusted.priorprob)

# Posterior mean of the probability of being a bridge
posteriorprob.mean <- 1 - output$res[[1]]$p_change

# Graph
colscale <- c("blue", "red")
pchscale <- c(16, 17)
# Index into the colour/symbol scales by the sorted factor level of `party`.
# NOTE(review): this assumes exactly two party levels; a third level would
# index past the end of both scales and yield NA colour/symbol.
party_idx <- as.numeric(as.factor(party))
jpeg(filename = file.path(out_path,
                          "fig1_supplement_priortoposteriorbridgingprob.jpeg"))
par(mar = c(4, 4, 1, 1) + 0.15)
plot(covadjusted.priorprob.mean, posteriorprob.mean,
     xlim = c(0, 1), ylim = c(0, 1),
     xlab = "Prior bridging probability (covariate adjusted)",
     ylab = "Posterior bridging probability", cex.lab = 1.35,
     col = colscale[party_idx], pch = pchscale[party_idx])
dev.off()


#-------------------------------------------------------------------------------
## SUPPLEMENT FIGURE 2: Compare ASF between models with and without covariates

CONGS <- 93:113
# One entry per Congress (all identical here); the readers below look up the
# entry that corresponds to the Congress being loaded.
in_file_suffix <- rep(paste0(args$burn, "_", args$iter, "_", args$thin),
                      length(CONGS))
in_path_stage0 <- rep(file.path("..", "results", "stage0"), length(CONGS))
in_path_stage1 <- rep(file.path("..", "results", "stage1"), length(CONGS))
out_path <- file.path("..", "results")
# Provides the `start` / `end` columns used for the shaded rectangles below.
party_info <- readRDS(file.path("..", "data", "party_info.RDS"))



## ASF for model with covariates
res_stage0 <-
  tibble(cong = CONGS) %>%
  mutate(res = map(cong, ~read_one(.x,
                                   in_path_stage0[which(CONGS == .x)],
                                   in_file_suffix[which(CONGS == .x)]))) %>%
  unnest(c(res)) %>%
  dplyr::select(cong, summy) %>%
  unnest(summy) %>%
  mutate(label = "Covariates")

## ASF for model without covariates
res_stage1 <-
  tibble(cong = CONGS) %>%
  mutate(res = map(cong, ~read_one_stage1(.x,
                                          in_path_stage1[which(CONGS == .x)],
                                          in_file_suffix[which(CONGS == .x)]))) %>%
  unnest(c(res)) %>%
  dplyr::select(cong, summy) %>%
  unnest(summy) %>%
  mutate(label = "No Covariates")


## Combine into a single tibble
summy <- bind_rows(res_stage0, res_stage1)

## Generate plot
# Keep the rows whose `type` is "All" or missing; draw the ASF point estimate
# with its ASF_L95 / ASF_U95 range per Congress, dodged by model.
summy %>%
  filter(type == "All" | is.na(type)) %>%
  ggplot(aes(x = cong, y = ASF, ymin = ASF_L95, ymax = ASF_U95,
             linetype = label)) +
  geom_linerange(position = position_dodge(.5),
                 show.legend = FALSE) +
  geom_point(position = position_dodge(.5), size = 1) +
  geom_rect(data = party_info, ymin = 0, ymax = 1,
            aes(xmin = start - .5, xmax = end + .5),
            inherit.aes = FALSE, alpha = .2) +
  xlab("House") + ylab("Bridging frequency") +
  scale_linetype_discrete("") +
  theme_bw()

# ggsave() is called without `plot`, so it writes the most recent ggplot.
ggsave(file.path(out_path, "fig2_supplement_compare_ASF.jpeg"),
       height = 2.5, width = 6)


#-------------------------------------------------------------------------------
## SUPPLEMENT FIGURE 3: Compare identity of bridges (a bridge defined as those
#           having bridging probability greater than 'threshold')

threshold <- 0.5
label <- c("onlyjoint", "onlynocovar", "bothb")

# Per-legislator bridging probability (1 - p_change) from the model with
# covariates.
res_stage0 <-
  tibble(cong = CONGS) %>%
  mutate(res = map(cong, ~read_one(.x,
                                   in_path_stage0[which(CONGS == .x)],
                                   in_file_suffix[which(CONGS == .x)]))) %>%
  unnest(res) %>%
  dplyr::select(cong, res) %>%
  unnest(res) %>%
  mutate(bridge_prob = 1 - p_change, label = "Covariates") %>%
  dplyr::select(cong, legislator, party, label, bridge_prob)

# Same quantity from the model without covariates.
res_stage1 <-
  tibble(cong = CONGS) %>%
  mutate(res = map(cong, ~read_one_stage1(.x,
                                          in_path_stage1[which(CONGS == .x)],
                                          in_file_suffix[which(CONGS == .x)]))) %>%
  unnest(res) %>%
  dplyr::select(cong, res) %>%
  unnest(res) %>%
  mutate(bridge_prob = 1 - p_change, label = "No Covariates") %>%
  dplyr::select(cong, legislator, party, label, bridge_prob)


# For each Congress, the share of legislators that only the covariate model
# ("onlyjoint") or only the no-covariate model ("onlynocovar") puts above the
# threshold. "bothb" is the complement of those two shares, so it covers every
# remaining legislator.
# NOTE(review): the comparison is positional -- it presumes both result sets
# list the legislators of a Congress in the same order; confirm against the
# readers.
matchcounts <- bind_rows(map(CONGS, function(i) {
  aa <- res_stage0 %>% filter(cong == i)
  bb <- res_stage1 %>% filter(cong == i)
  # Guard against silent recycling if the two models disagree on row count.
  stopifnot(nrow(aa) == nrow(bb))
  only_joint <- mean(aa$bridge_prob > threshold & bb$bridge_prob < threshold)
  only_nocovar <- mean(aa$bridge_prob < threshold & bb$bridge_prob > threshold)
  tibble(cong = i,
         label = label,
         p = c(only_joint, only_nocovar, 1 - only_nocovar - only_joint))
}))

matchcounts %>%
  ggplot(aes(x = cong, y = p,
             color = label, linetype = label)) +
  geom_rect(data = party_info, ymin = -.05, ymax = 1.05,
            aes(xmin = start - .5, xmax = end + .5),
            inherit.aes = FALSE, alpha = .1) +
  geom_line(show.legend = FALSE) +
  xlab("House") + ylab("Percent of legislators") +
  scale_y_continuous(limits = c(0, 1)) +
  theme_bw()

# ggsave() is called without `plot`, so it writes the most recent ggplot.
ggsave(file.path(out_path, "fig3_supplement_compare_bridges.jpeg"),
       height = 2.5, width = 6)
